/**
 * \file test/test_misc.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "lite_build_config.h"

#if LITE_BUILD_WITH_MGE
#include "../src/decryption/decrypt_base.h"
#include "../src/network_impl_base.h"
#include "test_common.h"

#include "megbrain/opr/io.h"
#include "megbrain/tensor.h"
#include "megbrain/utils/metahelper.h"

#include <gtest/gtest.h>

#include <string.h>
#include <chrono>
#include <memory>
#include <random>

using namespace lite;

//! Registering a new decryption method must grow the method table returned by
//! decryption_static_data() by exactly one entry.
TEST(TestMisc, DecryptionRegister) {
    size_t number = decryption_static_data().decryption_methods.size();
    //! At least one method is registered by lite itself
    ASSERT_GE(number, 1u);
    DecryptionFunc func;
    register_decryption_and_key("AllForTest0", func, {});

    ASSERT_EQ(number + 1, decryption_static_data().decryption_methods.size());
}

//! Register a method with an empty function and key, then check that
//! update_decryption_or_key can first install a function (with an empty key)
//! and afterwards, in a second call that passes no function, a 3-byte key.
TEST(TestMisc, DecryptionUpdate) {
    DecryptionFunc func;
    register_decryption_and_key("AllForTest1", func, {});
    func = [](const void*, size_t,
              const std::vector<uint8_t>&) -> std::vector<uint8_t> { return {}; };
    update_decryption_or_key("AllForTest1", func, {});

    //! Use at() instead of operator[]: if the method was never registered,
    //! at() throws (reported by gtest) rather than silently inserting a
    //! default entry whose null key pointer would be dereferenced below.
    auto method = []() -> auto& {
        return decryption_static_data().decryption_methods.at("AllForTest1");
    };

    ASSERT_NE(method().first, nullptr);
    ASSERT_NE(method().second, nullptr);
    ASSERT_EQ(method().second->size(), 0u);

    update_decryption_or_key("AllForTest1", {}, {1, 2, 3});
    ASSERT_NE(method().second, nullptr);
    ASSERT_EQ(method().second->size(), 3u);
}

//! Load the same model twice through one GraphLoader, compile each load's
//! outputs separately, and check that the ImmutableTensor oprs of the two
//! compiled functions (paired in execution order) point at the same storage.
TEST(TestMisc, SharedSameDeviceTensor) {
    using namespace mgb;
    serialization::GraphLoader::LoadConfig mgb_config;
    mgb_config.comp_node_mapper = [](CompNode::Locator& loc) {
        loc = to_compnode_locator(LiteDeviceType::LITE_CPU);
    };
    mgb_config.comp_graph = ComputingGraph::make();
    std::string model_path = "./shufflenet.mge";

    auto inp_file = mgb::serialization::InputFile::make_fs(model_path.c_str());
    auto format = serialization::GraphLoader::identify_graph_dump_format(*inp_file);
    mgb_assert(
            format.valid(),
            "invalid model: unknown model format, please make sure input "
            "file is generated by GraphDumper");
    auto loader = serialization::GraphLoader::make(std::move(inp_file), format.val());
    //! Both loads go through the same loader. NOTE(review): the `true`
    //! argument presumably rewinds the input file so it can be read a second
    //! time -- confirm against GraphLoader::load.
    auto load_ret_1 = loader->load(mgb_config, true);
    auto load_ret_2 = loader->load(mgb_config, true);
    ASSERT_EQ(load_ret_1.output_var_list.size(), load_ret_2.output_var_list.size());

    ComputingGraph::OutputSpec out_spec_1, out_spec_2;
    for (size_t i = 0; i < load_ret_1.output_var_list.size(); i++) {
        out_spec_1.emplace_back(load_ret_1.output_var_list[i], nullptr);
        out_spec_2.emplace_back(load_ret_2.output_var_list[i], nullptr);
    }

    //! Gather the ImmutableTensor oprs of a compiled function in the order
    //! iter_opr_seq visits them.
    auto collect_immutable_tensors = [](const auto& func) {
        std::vector<opr::ImmutableTensor*> found;
        func->iter_opr_seq([&found](cg::OperatorNodeBase* node) -> bool {
            if (auto* imm = node->try_cast_final<opr::ImmutableTensor>()) {
                found.push_back(imm);
            }
            return true;
        });
        return found;
    };

    //! Each load must be compiled with its OWN output spec and inspected via
    //! its OWN function; otherwise the test compares one function with itself.
    auto func_1 = load_ret_1.graph_compile(out_spec_1);
    auto oprs_1 = collect_immutable_tensors(func_1);
    auto func_2 = load_ret_2.graph_compile(out_spec_2);
    auto oprs_2 = collect_immutable_tensors(func_2);

    ASSERT_EQ(oprs_1.size(), oprs_2.size());
    for (size_t i = 0; i < oprs_1.size(); i++) {
        const auto& tensor_1 = oprs_1[i]->value();
        const auto& tensor_2 = oprs_2[i]->value();
        ASSERT_EQ(tensor_1.raw_ptr(), tensor_2.raw_ptr())
                << "ImmutableTensor #" << i << " is not shared between the two loads";
    }
}

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
